import argparse
import torch
import torch.nn as nn
import utils.datasets as dl
from utils.load_trained_model import load_model
import pathlib
import matplotlib as mpl
from tqdm import tqdm, trange
import numpy as np
from multiprocessing import Pool
import time
import ssl_utils as ssl
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Let cuDNN benchmark the available convolution algorithms and reuse the fastest;
# input shapes are mostly constant here (fixed image size and batch size per model).
torch.backends.cudnn.benchmark = True

# Select the non-interactive 'Agg' backend. This has to run before pyplot is
# imported, which is why the pyplot import sits here instead of at the top.
mpl.use('Agg')
import matplotlib.pyplot as plt


parser = argparse.ArgumentParser(description='Parse arguments.', prefix_chars='-')

# nargs='*' (instead of '+') allows `--gpu` to be given with no indices, which
# selects the CPU path below. With '+' an empty list was impossible, so the CPU
# branch was unreachable. type=int parses the indices once, here.
parser.add_argument('--gpu', '--list', nargs='*', type=int, default=[0],
                    help='GPU indices, if more than 1 parallel modules will be called')
hps = parser.parse_args()

# device_ids: None when running on the CPU, otherwise the list of CUDA indices
#             (more than one makes the script wrap the model in nn.DataParallel).
# device:     target for the model and input batches; the first listed GPU on CUDA.
if not hps.gpu:
    device_ids = None
    device = torch.device('cpu')
    print('Warning! Computing on CPU')
else:
    device_ids = list(hps.gpu)
    device = torch.device(f'cuda:{device_ids[0]}')

dataset = 'cifar100'
svhn_extra = True

# Output directory (with trailing slash, since file names are appended directly)
# for each supported in-distribution dataset. For SVHN the directory also
# depends on whether the "extra" split is evaluated as well.
_output_dirs = {
    'svhn': 'DatasetClassifications/80MSVHNExtra/' if svhn_extra else 'DatasetClassifications/80MSVHN/',
    'cifar10': 'DatasetClassifications/80MCifar10/',
    'cifar100': 'DatasetClassifications/80MCifar100/',
}
if dataset not in _output_dirs:
    raise ValueError('Dataset not supported')
path = _output_dirs[dataset]

# Create the directory (and any missing parents) up front; no-op if it exists.
pathlib.Path(path).mkdir(parents=True, exist_ok=True)

# Models to evaluate. Each entry is the tuple
#   (architecture, run folder, checkpoint, temperature, load_temp)
# which is unpacked by the evaluation loop and forwarded to load_model().
#
# Previously evaluated alternatives, kept for reference:
#   ('ResNet50', 'CEDA_16-02-2021_09:31:49', 'best', None, False)
#   ('WideResNet40x10', 'CEDA_05-12-2020_11:44:53', 'best', None, False)
#   ('WideResNet28x2', 'CEDA_03-08-2021_15:37:14', 'best', None, False)
#   ('WideResNet28x2', 'plain_24-03-2021_02:14:53', 'best', None, False)
#   ('WideResNet70x16', 'CEDA_09-12-2020_18:47:56', 'final_swa', None, False)
#   ('shakedrop_pyramid272', 'plain_31-10-2020_08:12:59', 'best', None, False)
#   ('shakedrop_pyramid272', 'plain_07-11-2020_14:01:08', 'best_swa', None, False)
#   ('BiT-M-R152x2', 'CEDA_09-10-2020_12:41:04', 'final', None, False)
#   ('BiT-S-R152x2', 'CEDA_08-10-2020_22:41:20', 'final', None, False)
model_descriptions = [
    ('ResNet50', run_folder, 'best', None, False)
    for run_folder in (
        'CEDA_02-10-2021_06:43:10',
        'CEDA_02-10-2021_06:43:14',
        'CEDA_02-10-2021_08:25:01',
        'CEDA_02-10-2021_08:25:04',
    )
]

def _predict_into(model, loader, outputs, device):
    """Run ``model`` over every batch of ``loader`` and write the results into ``outputs``.

    Rows of ``outputs`` are filled consecutively in loader order, so ``loader``
    should not shuffle, and ``outputs`` must have exactly as many rows as the
    loader yields samples in total (otherwise the slice assignment fails).
    Each loader item is expected to be an ``(inputs, targets)`` pair; the
    targets are ignored. Returns ``outputs``.
    """
    start = 0
    with torch.no_grad(), tqdm(total=len(loader), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}') as pbar:
        for inputs, _ in loader:
            end = start + inputs.shape[0]
            outputs[start:end] = model(inputs.to(device)).cpu()
            start = end
            pbar.update(1)
    return outputs


# The batch size scales with the number of GPUs that DataParallel splits it over.
# On the CPU device_ids is None, so a single device is assumed instead of
# calling len(None).
num_devices = len(device_ids) if device_ids is not None else 1
bs = 1024 * num_devices

# For every model: compute its outputs on 80M Tiny Images (followed by the SVHN
# extra split when requested) and store them as one tensor in {path}{folder}.pt.
for model_type, folder, checkpoint, temperature, temp in model_descriptions:
    # BiT models get 128px inputs from the loader; all other architectures get 32px.
    img_size = 128 if 'BiT' in model_type else 32

    model = load_model(model_type, folder, checkpoint, temperature, device, load_temp=temp, dataset=dataset)
    if device_ids is not None and len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model.eval()

    print(f'{folder} - {dataset}')

    TINY_LENGTH = dl.TINY_LENGTH
    if dataset in ['cifar10', 'svhn']:
        num_classes = 10
    elif dataset in ['cifar100']:
        num_classes = 100
    else:
        raise NotImplementedError()
    model_outs = torch.zeros((TINY_LENGTH, num_classes))

    # Standard implementation; resizing to img_size happens in the loader's worker processes.
    tiny_images_loader = dl.get_80MTinyImages(batch_size=bs, augm_type='none', shuffle=False, num_workers=32,
                                              size=img_size)
    _predict_into(model, tiny_images_loader, model_outs, device)

    if dataset == 'svhn' and svhn_extra:
        # Use a distinct name so the module-level ``svhn_extra`` flag keeps its
        # boolean value for the next model in the loop.
        svhn_extra_loader = ssl.get_SVHNValidationExtraSplit(split='extra-split', shuffle=False, batch_size=bs,
                                                             augm_type='none')
        svhn_outs = torch.zeros((len(svhn_extra_loader.dataset), num_classes))
        _predict_into(model, svhn_extra_loader, svhn_outs, device)
        # Tiny Images rows come first, followed by the SVHN extra-split rows.
        model_outs = torch.cat([model_outs, svhn_outs], dim=0)

    torch.save(model_outs, f'{path}{folder}.pt')

